{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "simple_api_fine_tuning.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xT4_NpW-TotE"
      },
      "source": [
        "# Welcome to `jiant`\n",
        "This notebook contains an example of fine-tuning a `roberta-base` model on the MRPC task using the simple `jiant` API."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JlM8H-WCoh9k"
      },
      "source": [
        "# Install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rY-7weGtIEUX"
      },
      "source": [
        "%%capture\n",
        "!git clone https://github.com/nyu-mll/jiant.git\n",
        "# This Colab notebook already has its CUDA-runtime compatible versions of torch and torchvision installed\n",
        "!pip install -r jiant/requirements-no-torch.txt"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j78UDhA7UMzi"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5hsmmr9eIJJt"
      },
      "source": [
        "import sys\n",
        "sys.path.insert(0, \"/content/jiant\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UXYZHyyNGayI"
      },
      "source": [
        "import os\n",
        "\n",
        "import jiant.utils.python.io as py_io\n",
        "import jiant.proj.simple.runscript as simple_run\n",
        "import jiant.scripts.download_data.runscript as downloader"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aow1wIIDUS4h"
      },
      "source": [
        "# Define task and model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AEtgbtkRHDJE"
      },
      "source": [
        "# See https://github.com/nyu-mll/jiant/blob/master/guides/tasks/supported_tasks.md for supported tasks\n",
        "TASK_NAME = \"mrpc\"\n",
        "\n",
        "# See https://huggingface.co/models for supported models\n",
        "HF_PRETRAINED_MODEL_NAME = \"roberta-base\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xDjg1tRNUk9r"
      },
      "source": [
        "# Create directories for task data and experiment"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Gxy_csM9UhhA"
      },
      "source": [
        "# Remove forward slashes so RUN_NAME can be used as path\n",
        "MODEL_NAME = HF_PRETRAINED_MODEL_NAME.split(\"/\")[-1]\n",
        "RUN_NAME = f\"simple_{TASK_NAME}_{MODEL_NAME}\"\n",
        "EXP_DIR = \"/content/exp\"\n",
        "DATA_DIR = \"/content/exp/tasks\"\n",
        "\n",
        "os.makedirs(DATA_DIR, exist_ok=True)\n",
        "os.makedirs(EXP_DIR, exist_ok=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8tXDQ1P2Unfa"
      },
      "source": [
        "#Download data (uses `nlp` or direct download depending on task)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_HCQG8fEU4CU"
      },
      "source": [
        "downloader.download_data([TASK_NAME], DATA_DIR)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZsInlNWLU5ZU"
      },
      "source": [
        "#Run simple `jiant` pipeline (train and evaluate on MRPC)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Po0N521IHAjj"
      },
      "source": [
        "args = simple_run.RunConfiguration(\n",
        "    run_name=RUN_NAME,\n",
        "    exp_dir=EXP_DIR,\n",
        "    data_dir=DATA_DIR,\n",
        "    hf_pretrained_model_name_or_path=HF_PRETRAINED_MODEL_NAME,\n",
        "    tasks=TASK_NAME,\n",
        "    train_batch_size=16,\n",
        "    num_train_epochs=1\n",
        ")\n",
        "simple_run.run_simple(args)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2GLUP22PoowE"
      },
      "source": [
        "The simple API `RunConfiguration` object is saved as `simple_run_config.json`. `simple_run_config.json` can be loaded and used as inputs to repeat experiments as follows."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ckhYG6Ijh1nC",
        "collapsed": true
      },
      "source": [
        "args = simple_run.RunConfiguration.from_json_path(os.path.join(EXP_DIR, \"runs\", RUN_NAME, \"simple_run_config.json\"))\n",
        "simple_run.run_simple(args)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}